#!/usr/bin/env python
#
# smem - a tool for meaningful memory reporting
#
# Copyright 2008-2009 Matt Mackall <mpm@selenic.com>
#
# This software may be used and distributed according to the terms of
# the GNU General Public License version 2 or later, incorporated
# herein by reference.

import re, os, sys, pwd, optparse, errno, tarfile

# Set to True by pidmaps() once it has printed the "kernel does not appear
# to support PSS measurement" warning, so the warning is emitted only once.
warned = False

class procdata(object):
    '''read process and memory data from files under <source>/proc

    source is a path prefix placed in front of "/proc"; a false value
    (None or "") selects the live /proc of the running system.
    User and group name lookups are cached per instance.
    '''
    def __init__(self, source):
        self._ucache = {}
        self._gcache = {}
        self.source = source or ""
        self._memdata = None

    def _list(self):
        '''names of all entries directly under the proc directory'''
        return os.listdir(self.source + "/proc")

    def _read(self, f):
        '''whole contents of the proc file f (path relative to proc)'''
        # open() works where the file() builtin does not, and the handle is
        # closed explicitly instead of being left to garbage collection
        handle = open(self.source + '/proc/' + f)
        try:
            return handle.read()
        finally:
            handle.close()

    def _readlines(self, f):
        '''lines of the proc file f, line endings kept'''
        return self._read(f).splitlines(True)

    def _stat(self, f):
        '''os.stat() result for the proc entry f'''
        return os.stat(self.source + "/proc/" + f)

    def pids(self):
        '''get a list of processes'''
        return [int(e) for e in self._list()
                if e.isdigit() and not iskernel(e)]

    def mapdata(self, pid):
        '''lines of the smaps file of pid'''
        return self._readlines('%s/smaps' % pid)

    def memdata(self):
        '''lines of meminfo, read once and then cached'''
        if self._memdata is None:
            self._memdata = self._readlines('meminfo')
        return self._memdata

    def version(self):
        '''first line of the version file'''
        return self._readlines('version')[0]

    def pidname(self, pid):
        '''name of pid as found between parentheses in its stat file

        Returns '?' when the name cannot be read.
        '''
        try:
            l = self._read('%d/stat' % pid)
            # the name itself may contain ')', so it ends at the last one
            return l[l.find('(') + 1: l.rfind(')')]
        except Exception:
            return '?'

    def pidcmd(self, pid):
        '''command line of pid with NUL separators turned into spaces

        Returns '?' when the command line cannot be read.
        '''
        try:
            c = self._read('%s/cmdline' % pid)[:-1]
            return c.replace('\0', ' ')
        except Exception:
            return '?'

    def piduser(self, pid):
        '''numeric owner of the proc entry of pid, or -1 on failure'''
        try:
            return self._stat('%d' % pid).st_uid
        except Exception:
            return -1

    def pidgroup(self, pid):
        '''numeric group of the proc entry of pid, or -1 on failure'''
        try:
            return self._stat('%d' % pid).st_gid
        except Exception:
            return -1

    def username(self, uid):
        '''user name for uid; '?' for -1, the number as text if unknown'''
        if uid == -1:
            return '?'
        if uid not in self._ucache:
            try:
                self._ucache[uid] = pwd.getpwuid(uid)[0]
            except KeyError:
                self._ucache[uid] = str(uid)
        return self._ucache[uid]

    def groupname(self, gid):
        '''group name for gid; '?' for -1, the number as text if unknown'''
        if gid == -1:
            return '?'
        if gid not in self._gcache:
            # group lookups live in the grp module; pwd has no getgrgid()
            import grp
            try:
                self._gcache[gid] = grp.getgrgid(gid)[0]
            except KeyError:
                self._gcache[gid] = str(gid)
        return self._gcache[gid]

class tardata(procdata):
    '''procdata variant that takes its data from a tar archive'''
    def __init__(self, source):
        procdata.__init__(self, source)
        self.tar = tarfile.open(source)

    def _list(self):
        '''yield the directory part of every archive member named */smaps'''
        for member in self.tar:
            if not member.name.endswith('/smaps'):
                continue
            piddir, leaf = member.name.split('/')
            yield piddir

    def _read(self, f):
        '''whole contents of archive member f'''
        stream = self.tar.extractfile(f)
        return stream.read()

    def _readlines(self, f):
        '''lines of archive member f'''
        stream = self.tar.extractfile(f)
        return stream.readlines()

    def piduser(self, p):
        '''uid recorded for member p; remembers its user name if present'''
        info = self.tar.getmember("%d" % p)
        if info.uname:
            self._ucache[info.uid] = info.uname
        return info.uid

    def pidgroup(self, p):
        '''gid recorded for member p; remembers its group name if present'''
        info = self.tar.getmember("%d" % p)
        if info.gname:
            self._gcache[info.gid] = info.gname
        return info.gid

    def username(self, u):
        '''remembered user name for u, else the id as text'''
        fallback = str(u)
        if u in self._ucache:
            return self._ucache[u]
        return fallback

    def groupname(self, g):
        '''remembered group name for g, else the id as text'''
        fallback = str(g)
        if g in self._gcache:
            return self._gcache[g]
        return fallback

# cached result of totalmem(); 0 means "not computed yet"
_totalmem = 0
def totalmem():
    '''total memory, taken from --realmem if given, else meminfo memtotal

    The value is computed on first use and reused afterwards.
    '''
    global _totalmem
    if _totalmem:
        return _totalmem
    if options.realmem:
        _totalmem = fromunits(options.realmem) / 1024
    else:
        _totalmem = memory()['memtotal']
    return _totalmem

# cached result of kernelsize(); 0 means "unknown or not computed yet"
_kernelsize = 0
def kernelsize():
    '''size of the kernel image named by --kernel, divided by 1024

    The figure is the fourth column of the second output line of the
    "size" command run on the image. When that fails, the image is
    scanned for a gzip signature to suggest how to unpack it, a hint
    is written to stderr and 0 is returned.
    '''
    global _kernelsize
    if not _kernelsize and options.kernel:
        try:
            # NOTE(review): the path is interpolated into a shell command
            # unquoted; paths with spaces or shell metacharacters will break
            pipe = os.popen("size %s" % options.kernel)
            try:
                d = pipe.readlines()[1]
            finally:
                pipe.close()
            _kernelsize = int(d.split()[3]) / 1024
        except Exception:
            try:
                # try some heuristic to find gzipped part in kernel image
                image = open(options.kernel, 'rb')
                try:
                    packedkernel = image.read()
                finally:
                    image.close()
                pos = packedkernel.find('\x1F\x8B')
                if pos >= 0 and pos < 25000:
                    sys.stderr.write("Maybe uncompressed kernel can be extracted by the command:\n"
                            "  dd if=%s bs=1 skip=%d | gzip -d >%s.unpacked\n\n" % (options.kernel, pos, options.kernel))
            except Exception:
                # the hint is best effort only; the message below still
                # tells the user what kind of file is expected
                pass
            sys.stderr.write("Parameter '%s' should be an original uncompressed compiled kernel file.\n\n" % options.kernel)
    return _kernelsize

def pidmaps(pid):
    '''parse the smaps data of pid into a dict keyed by mapping start

    Each value holds end, mode, offset, device, inode and name of the
    mapping plus one lower-cased entry per "kB" line that follows it.
    Warns once if no Pss line was seen, and then makes 'rss' the sort
    field unless one was already chosen. When a map filter is set,
    only mappings that pass it are returned.
    '''
    global warned
    maps = {}
    start = None
    seen = False
    empty = True
    for line in src.mapdata(pid):
        empty = False
        words = line.split()
        head = words[0]
        if words[-1] == 'kB':
            # a counter line belonging to the most recent mapping
            if head.startswith('Pss'):
                seen = True
            maps[start][head[:-1].lower()] = int(words[1])
        elif '-' in head and ':' not in head: # looks like a mapping range
            low, high = head.split('-')
            start = int(low, 16)
            if len(words) > 5:
                name = words[5]
            else:
                name = "<anonymous>"
            maps[start] = {'end': int(high, 16), 'mode': words[1],
                           'offset': int(words[2], 16),
                           'device': words[3], 'inode': words[4],
                           'name': name}

    if not (empty or seen or warned):
        sys.stderr.write('warning: kernel does not appear to support PSS measurement\n')
        warned = True
        if not options.sort:
            options.sort = 'rss'

    if not options.mapfilter:
        return maps

    kept = {}
    for addr in maps:
        if not filter(options.mapfilter, addr, lambda a: maps[a]['name']):
            kept[addr] = maps[addr]
    return kept

def sortmaps(totals, key):
    '''pids of totals ordered by totals[pid][key], ties broken by pid'''
    ranked = sorted([(totals[pid][key], pid) for pid in totals])
    return [pid for value, pid in ranked]

def iskernel(pid):
    '''true when the command line reported for pid is the empty string'''
    cmdline = src.pidcmd(pid)
    return cmdline == ""

def memory():
    '''parse the "Name: N kB" lines of the memory data into a dict

    Keys are the lower-cased names, values the integer amounts.
    '''
    pattern = re.compile('(\\S+):\\s+(\\d+) kB')
    parsed = {}
    for line in src.memdata():
        hit = pattern.match(line)
        if not hit:
            continue
        name, amount = hit.group(1), hit.group(2)
        parsed[name.lower()] = int(amount)
    return parsed

def units(x):
    '''format x with one decimal and a binary suffix ('', K, M, G or T)

    0 is returned as plain '0'. Values too large for 'G' are all
    expressed in 'T', however large they are.
    '''
    if x == 0:
        return '0'
    suffixes = ('', 'K', 'M', 'G', 'T')
    for s in suffixes[:-1]:
        if x < 1024:
            break
        x /= 1024.0
    else:
        # ran out of smaller suffixes: x is already scaled for the largest
        # one and must not be divided again
        s = suffixes[-1]
    return "%.1f%s" % (x, s)

def fromunits(x):
    '''convert a size string such as "1024M" or "1.5GB" to an integer

    Accepted suffixes are k, K, kB, KB, M, MB, G, GB, T and TB, all
    powers of 1024. If the suffix is missing or the numeric part is
    not a usable number, a hint is written to stderr and the program
    exits with status -1.
    '''
    s = dict(k=2**10, K=2**10, kB=2**10, KB=2**10,
             M=2**20, MB=2**20, G=2**30, GB=2**30,
             T=2**40, TB=2**40)
    # longest suffixes first, so matching does not depend on dict order
    for k in sorted(s, key=len, reverse=True):
        if x.endswith(k):
            try:
                return int(float(x[:-len(k)]) * s[k])
            except (ValueError, OverflowError):
                # numeric part is empty, malformed, inf or nan
                break
    sys.stderr.write("Memory size should be written with units, for example 1024M\n")
    sys.exit(-1)

def pidusername(pid):
    '''user name of the owner of pid, as reported by the data source'''
    uid = src.piduser(pid)
    return src.username(uid)

def showamount(a, total):
    '''render amount a according to the --abbreviate / --percent options

    With --abbreviate the amount times 1024 is shown with a unit
    suffix; with --percent it is shown relative to total ('N/A' when
    total is 0); otherwise a is returned unchanged.
    '''
    if options.abbreviate:
        return units(a * 1024)
    if not options.percent:
        return a
    if total == 0:
        return 'N/A'
    return "%.2f%%" % (100.0 * a / total)

def filter(opt, arg, *sources):
    '''decide whether arg should be filtered out by the regex opt

    Returns False when opt is empty/unset or when opt matches the text
    produced by any of the sources for arg; returns True otherwise.
    '''
    if opt:
        for source in sources:
            text = source(arg)
            if re.search(opt, text):
                return False
        return True
    return False

def pidtotals(pid):
    '''sum the per-mapping counters of pid into one dict

    Besides the summed counters the result carries 'uss' (private
    clean plus private dirty) and 'maps' (number of mappings).
    '''
    maps = pidmaps(pid)
    counters = ('size', 'rss', 'pss', 'shared_clean', 'shared_dirty',
                'private_clean', 'private_dirty', 'referenced', 'swap')
    t = dict.fromkeys(counters, 0)
    for mapping in maps.values():
        for name in counters:
            t[name] += mapping.get(name, 0)

    t['uss'] = t['private_clean'] + t['private_dirty']
    t['maps'] = len(maps)

    return t

def processtotals(pids):
    '''collect pidtotals() for every pid that passes the filters

    Pids rejected by the process or user filter, pids without any
    mapping and pids whose data cannot be read or parsed are left out.
    Returns a dict mapping pid to its totals.
    '''
    totals = {}
    for pid in pids:
        if (filter(options.processfilter, pid, src.pidname, src.pidcmd) or
                filter(options.userfilter, pid, pidusername)):
            continue
        try:
            p = pidtotals(pid)
        except Exception:
            # processes may vanish or be unreadable while being scanned;
            # skip them, but let KeyboardInterrupt/SystemExit propagate
            continue
        if p['maps'] != 0:
            totals[pid] = p
    return totals

def showpids():
    '''print the per-process table'''
    p = src.pids()
    pt = processtotals(p)

    def showuser(p):
        if options.numeric:
            return src.piduser(p)
        return pidusername(p)

    def total(key):
        # accessor returning one counter of the totals of a pid
        return lambda n: pt[n][key]

    fields = {
        'pid': ('PID', lambda n: n, '% 5s', lambda x: len(pt),
                'process ID'),
        'user': ('User', showuser, '%-8s', lambda x: len(dict.fromkeys(x)),
                 'owner of process'),
        'name': ('Name', src.pidname, '%-24.24s', None,
                 'name of process'),
        'command': ('Command', src.pidcmd, '%-27.27s', None,
                    'process command line'),
        'maps': ('Maps', total('maps'), '% 5s', sum,
                 'total number of mappings'),
        'swap': ('Swap', total('swap'), '% 8a', sum,
                 'amount of swap space consumed (ignoring sharing)'),
        'uss': ('USS', total('uss'), '% 8a', sum,
                'unique set size'),
        'rss': ('RSS', total('rss'), '% 8a', sum,
                'resident set size (ignoring sharing)'),
        'pss': ('PSS', total('pss'), '% 8a', sum,
                'proportional set size (including sharing)'),
        'vss': ('VSS', total('size'), '% 8a', sum,
                'virtual set size (total virtual memory mapped)'),
    }

    columns = options.columns
    if not columns:
        columns = 'pid user command swap uss pss rss'

    showtable(pt.keys(), fields, columns.split(), options.sort or 'pss')

def maptotals(pids):
    '''sum mapping counters by mapping name over all unfiltered pids

    Each entry also carries 'count' (number of mappings of that name)
    and 'pids' (number of processes having at least one such mapping).
    Pids whose maps cannot be read are skipped.
    '''
    totals = {}
    for pid in pids:
        excluded = (filter(options.processfilter, pid, src.pidname, src.pidcmd) or
                    filter(options.userfilter, pid, pidusername))
        if excluded:
            continue
        try:
            # names already counted towards 'pids' for this process
            counted = {}
            for mapping in pidmaps(pid).values():
                name = mapping['name']
                acc = totals.get(name)
                if acc is None:
                    acc = dict(size=0, rss=0, pss=0, shared_clean=0,
                               shared_dirty=0, private_clean=0, count=0,
                               private_dirty=0, referenced=0, swap=0, pids=0)

                for field in acc:
                    acc[field] += mapping.get(field, 0)
                acc['count'] += 1
                if name not in counted:
                    counted[name] = 1
                    acc['pids'] += 1
                totals[name] = acc
        except EnvironmentError:
            continue
    return totals

def showmaps():
    '''print the table of memory use grouped by mapping name'''
    p = src.pids()
    pt = maptotals(p)

    def uss(n):
        # maptotals() keeps no 'uss' entry, so derive it from the private
        # counters; used by both the USS and the AVGUSS column
        return pt[n]['private_clean'] + pt[n]['private_dirty']

    fields = dict(
        map=('Map', lambda n: n, '%-40.40s', len,
             'mapping name'),
        count=('Count', lambda n: pt[n]['count'], '% 5s', sum,
               'number of mappings found'),
        pids=('PIDs', lambda n: pt[n]['pids'], '% 5s', sum,
              'number of PIDs using mapping'),
        swap=('Swap',lambda n: pt[n]['swap'], '% 8a', sum,
              'amount of swap space consumed (ignoring sharing)'),
        uss=('USS', uss, '% 8a', sum,
             'unique set size'),
        rss=('RSS', lambda n: pt[n]['rss'], '% 8a', sum,
             'resident set size (ignoring sharing)'),
        pss=('PSS', lambda n: pt[n]['pss'], '% 8a', sum,
             'proportional set size (including sharing)'),
        vss=('VSS', lambda n: pt[n]['size'], '% 8a', sum,
             'virtual set size (total virtual address space mapped)'),
        avgpss=('AVGPSS', lambda n: int(1.0 * pt[n]['pss']/pt[n]['pids']),
                '% 8a', sum,
                'average PSS per PID'),
        avguss=('AVGUSS', lambda n: int(1.0 * uss(n)/pt[n]['pids']),
                '% 8a', sum,
                'average USS per PID'),
        avgrss=('AVGRSS', lambda n: int(1.0 * pt[n]['rss']/pt[n]['pids']),
                '% 8a', sum,
                'average RSS per PID'),
        )
    columns = options.columns or 'map pids avgpss pss'

    showtable(pt.keys(), fields, columns.split(), options.sort or 'pss')

def usertotals(pids):
    '''sum mapping counters by owning user over all unfiltered pids

    Each entry also carries 'count', the number of processes added for
    that user. Pids without mappings or with unreadable maps are
    skipped. Keys of the result are the values src.piduser() returns.
    '''
    totals = {}
    for pid in pids:
        excluded = (filter(options.processfilter, pid, src.pidname, src.pidcmd) or
                    filter(options.userfilter, pid, pidusername))
        if excluded:
            continue
        try:
            maps = pidmaps(pid)
        except EnvironmentError:
            continue
        if len(maps) == 0:
            continue

        user = src.piduser(pid)
        acc = totals.get(user)
        if acc is None:
            acc = dict(size=0, rss=0, pss=0, shared_clean=0,
                       shared_dirty=0, private_clean=0, count=0,
                       private_dirty=0, referenced=0, swap=0)

        for mapping in maps.values():
            for field in acc:
                acc[field] += mapping.get(field, 0)

        acc['count'] += 1
        totals[user] = acc
    return totals

def showusers():
    '''print the table of memory use grouped by user'''
    p = src.pids()
    pt = usertotals(p)

    def showuser(u):
        if options.numeric:
            return u
        return src.username(u)

    fields = dict(
        user=('User', showuser, '%-8s', None,
              'user name or ID'),
        count=('Count', lambda n: pt[n]['count'], '% 5s', sum,
               'number of processes'),
        swap=('Swap',lambda n: pt[n]['swap'], '% 8a', sum,
              'amount of swapspace consumed (ignoring sharing)'),
        uss=('USS', lambda n: pt[n]['private_clean']
             + pt[n]['private_dirty'], '% 8a', sum,
             'unique set size'),
        rss=('RSS', lambda n: pt[n]['rss'], '% 8a', sum,
             'resident set size (ignoring sharing)'),
        pss=('PSS', lambda n: pt[n]['pss'], '% 8a', sum,
             'proportional set size (including sharing)'),
        # virtual size is the summed 'size' counter, as in the other views
        vss=('VSS', lambda n: pt[n]['size'], '% 8a', sum,
             'virtual set size (total virtual memory mapped)'),
        )
    columns = options.columns or 'user count swap uss pss rss'

    showtable(pt.keys(), fields, columns.split(), options.sort or 'pss')

def showsystem():
    '''print the whole-system breakdown of memory by area'''
    total = totalmem()
    kernel_image = kernelsize()
    info = memory()

    memtotal = info['memtotal']
    free = info['memfree']

    # memory not visible to the kernel and not part of its image
    firmware = max(total - memtotal - kernel_image, 0)

    # memory in use by userspace: anonymous pages plus mapped pages
    userspace = info['anonpages'] + info['mapped']

    # what remains of the used memory is attributed to the kernel
    kernel_dynamic = memtotal - free - userspace

    # the part of that held in kernel caches
    kernel_cache = (info['buffers'] + info['sreclaimable']
                    + (info['cached'] - info['mapped']))

    # one (area, used, cache) triple per table row
    areas = [("firmware/hardware", firmware, 0),
             ("kernel image", kernel_image, 0),
             ("kernel dynamic memory", kernel_dynamic, kernel_cache),
             ("userspace memory", userspace, info['mapped']),
             ("free memory", free, free)]

    def blank(x):
        return ''

    fields = {
        'order': ('Order', lambda n: n, '% 1s', blank,
                  'hierarchical order'),
        'area': ('Area', lambda n: areas[n][0], '%-24s', blank,
                 'memory area'),
        'used': ('Used', lambda n: areas[n][1], '%10a', sum,
                 'area in use'),
        'cache': ('Cache', lambda n: areas[n][2], '%10a', sum,
                  'area used as reclaimable cache'),
        'noncache': ('Noncache', lambda n: areas[n][1] - areas[n][2],
                     '%10a', sum,
                     'area not reclaimable'),
    }

    columns = options.columns
    if not columns:
        columns = 'area used cache noncache'
    showtable(range(len(areas)), fields, columns.split(),
              options.sort or 'order')

def showfields(fields, f):
    if f != list:
        print "unknown field", f
    print "known fields:"
    for l in sorted(fields.keys()):
        print "%-8s %s" % (l, fields[l][-1])

def showtable(rows, fields, columns, sort):
    '''print (or plot) a table of the given rows

    rows    -- keys handed to each field's accessor, one per table row
    fields  -- dict mapping a field name to a tuple of
               (header, accessor, format, totals function or None,
               description)
    columns -- list of field names to show; the --pie / --bar field is
               appended to this list in place
    sort    -- name of the field whose value orders the rows

    An unknown sort field or column lists the known fields and exits
    with status -1. With --pie or --bar the sorted rows are handed to
    showpie() / showbar() instead of being printed.
    '''
    header = ""
    format = ""
    formatter = []

    if sort not in fields:
            showfields(fields, sort)
            sys.exit(-1)

    # the charted field rides along as the last column of every row
    if options.pie:
        columns.append(options.pie)
    if options.bar:
        columns.append(options.bar)

    mt = totalmem()
    st = memory()['swaptotal']

    for n in columns:
        if n not in fields:
            showfields(fields, n)
            sys.exit(-1)

        f = fields[n][2]
        # an 'a' in the format marks an amount column: its values go through
        # showamount() (relative to swap total for 'swap', else to total
        # memory) and are then printed as strings
        if 'a' in f:
            if n == 'swap':
                formatter.append(lambda x: showamount(x, st))
            else:
                formatter.append(lambda x: showamount(x, mt))
            f = f.replace('a', 's')
        else:
            formatter.append(lambda x: x)
        format += f + " "
        header += f % fields[n][0] + " "

    # build (sort value, [column values]) pairs so a plain sort orders rows
    l = []
    for n in rows:
        r = [fields[c][1](n) for c in columns]
        l.append((fields[sort][1](n), r))

    l.sort(reverse=bool(options.reverse))

    if options.pie:
        showpie(l, sort)
        return
    elif options.bar:
        showbar(l, columns, sort)
        return

    if not options.no_header:
        print header

    for k,r in l:
        print format % tuple([f(v) for f,v in zip(formatter, r)])

    if options.totals:
        # totals
        t = []
        for c in columns:
            # columns without a totals function get an empty cell
            f = fields[c][3]
            if f:
                t.append(f([fields[c][1](n) for n in rows]))
            else:
                t.append("")

        print "-" * len(header)
        print format % tuple([f(v) for f,v in zip(formatter, t)])

def showpie(l, sort):
    '''show a pie chart of the sort values in l

    l is a list of (sort value, row) pairs; the last item of each row
    is used as the slice label. Small trailing slices are merged into
    'other', and memory not accounted for by any row is shown as
    'unused'. Exits with status -1 if matplotlib is not available.
    '''
    try:
        import pylab
    except ImportError:
        sys.stderr.write("pie chart requires matplotlib\n")
        sys.exit(-1)

    # largest first; l may be empty when every row was filtered out
    if l and l[0][0] < l[-1][0]:
        l.reverse()

    labels = [r[1][-1] for r in l]
    values = [r[0] for r in l] # sort field

    tm = totalmem()
    s = sum(values)
    unused = tm - s
    t = 0
    # fold the smallest slices together while the merged slice stays under
    # 2% of total memory or the next slice is itself under 0.5%
    while values and (t + values[-1] < (tm * .02) or
                      values[-1] < (tm * .005)):
        t += values.pop()
        labels.pop()

    if t:
        values.append(t)
        labels.append('other')

    explode = [0] * len(values)
    if unused > 0:
        values.insert(0, unused)
        labels.insert(0, 'unused')
        explode.insert(0, .05)

    pylab.figure(1, figsize=(6,6))
    pylab.axes([0.1, 0.1, 0.8, 0.8])
    pylab.pie(values, explode = explode, labels=labels,
              autopct="%.2f%%", shadow=True)
    pylab.title('%s by %s' % (options.pie, sort))
    pylab.show()

def showbar(l, columns, sort):
    '''show a grouped bar chart of the numeric columns in l

    l is a list of (sort value, row) pairs and columns the matching
    column names; the last column supplies the tick labels and is not
    plotted. The pid, user and group columns and columns whose first
    value is not numeric are left out. sort is accepted for symmetry
    with showpie() and is not used. Exits with status -1 if matplotlib
    is not available.
    '''
    try:
        import pylab, numpy
    except ImportError:
        sys.stderr.write("bar chart requires matplotlib\n")
        sys.exit(-1)

    # largest first; l may be empty when every row was filtered out
    if l and l[0][0] < l[-1][0]:
        l.reverse()

    rc = []
    key = []
    for n in range(len(columns) - 1):
        if columns[n] in ('pid', 'user', 'group'):
            continue
        try:
            # probe the first row to see whether the column is numeric
            float(l[0][1][n])
        except (TypeError, ValueError, IndexError):
            continue
        rc.append(n)
        key.append(columns[n])

    width = 1.0 / (len(rc) + 1)
    offset = width / 2

    def gc(n):
        return 'bgrcmyw'[n % 7]

    pl = []
    ind = numpy.arange(len(l))
    for n, col in enumerate(rc):
        pl.append(pylab.bar(ind + offset + width * n,
                             [x[1][col] for x in l], width, color=gc(n)))

    pylab.gca().set_xticks(ind + .5)
    pylab.gca().set_xticklabels([x[1][-1] for x in l], rotation=45)
    pylab.legend([p[0] for p in pl], key)
    pylab.show()


# Command-line interface. The parsed result is stored in the module-level
# name "options", which the functions above read directly.
parser = optparse.OptionParser("%prog [options]")
# table layout
parser.add_option("-H", "--no-header", action="store_true",
                  help="disable header line")
parser.add_option("-c", "--columns", type="str",
                  help="columns to show")
parser.add_option("-t", "--totals", action="store_true",
                  help="show totals")

# inputs for the whole-system view
parser.add_option("-R", "--realmem", type="str",
                  help="amount of physical RAM")
parser.add_option("-K", "--kernel", type="str",
                  help="path to kernel image")

# report selection; the per-process view is used when none is given
parser.add_option("-m", "--mappings", action="store_true",
                  help="show mappings")
parser.add_option("-u", "--users", action="store_true",
                  help="show users")
parser.add_option("-w", "--system", action="store_true",
                  help="show whole system")

# regular-expression filters
parser.add_option("-P", "--processfilter", type="str",
                  help="process filter regex")
parser.add_option("-M", "--mapfilter", type="str",
                  help="map filter regex")
parser.add_option("-U", "--userfilter", type="str",
                  help="user filter regex")

# ordering and identification
parser.add_option("-n", "--numeric", action="store_true",
                  help="numeric output")
parser.add_option("-s", "--sort", type="str",
                  help="field to sort on")
parser.add_option("-r", "--reverse", action="store_true",
                  help="reverse sort")

# amount formatting
parser.add_option("-p", "--percent", action="store_true",
                  help="show percentage")
parser.add_option("-k", "--abbreviate", action="store_true",
                  help="show unit suffixes")

# charts; the argument names the field used for labels
parser.add_option("", "--pie", type='str',
                  help="show pie graph")
parser.add_option("", "--bar", type='str',
                  help="show bar graph")

# alternative data source instead of the live /proc
parser.add_option("-S", "--source", type="str",
                  help="/proc data source")


# no option defaults are set; unset options parse as None
defaults = {}
parser.set_defaults(**defaults)
(options, args) = parser.parse_args()

try:
    src = tardata(options.source)
except:
    src = procdata(options.source)

try:
    if options.mappings:
        showmaps()
    elif options.users:
        showusers()
    elif options.system:
        showsystem()
    else:
        showpids()
except IOError, e:
    if e.errno == errno.EPIPE:
        pass
except KeyboardInterrupt:
    pass
